#!/bin/bash
#
# Launches MaxText/train.py with MaxText/configs/base.yml plus a fixed set of
# key=value overrides, after exporting LIBTPU_INIT_ARGS for the child process.
#
# Usage: run from the directory that contains MaxText/ (both paths below are
#        relative). The script takes no arguments.
#
# NOTE(review): the gs://your-prefix-artifact-repository/... paths look like
# placeholders that must be replaced with a real bucket -- confirm before use.

# Abort as soon as any command returns a non-zero status.
set -e

# XLA/TPU flags, one per line so individual flags are easy to diff or toggle.
# NOTE(review): presumably consumed by libtpu at initialisation -- confirm.
libtpu_flags=(
  --xla_tpu_enable_data_parallel_all_reduce_opt=true
  --xla_tpu_data_parallel_opt_different_sized_ops=true
  --xla_tpu_enable_async_collective_fusion=true
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true
  --xla_tpu_overlap_compute_collective_tc=true
  --xla_enable_async_all_gather=true
)

# "${array[*]}" joins the elements with the first character of IFS (a space
# by default), producing a single space-separated string.
export LIBTPU_INIT_ARGS="${libtpu_flags[*]}"

# Positional arguments for train.py: the base config file first, then the
# key=value overrides in their original order.
train_args=(
  MaxText/configs/base.yml
  run_name=maxtext-single-slice-201
  dataset_path=gs://your-prefix-artifact-repository/datasets
  base_output_directory=gs://your-prefix-artifact-repository/runs
  steps=150
  log_period=50
  per_device_batch_size=6
  global_parameter_scale=8
  enable_checkpointing=false
  enable_profiler=false
  remat_policy=full
  dcn_data_parallelism=1
  ici_fsdp_parallelism=16
)

# Last command in the script, so its exit status becomes the script's status.
python3 MaxText/train.py "${train_args[@]}"
